package com.i72.freeway;

import javassist.*;
import lombok.extern.slf4j.Slf4j;

import java.lang.reflect.Method;
import java.util.*;

/**
 * @author jiangj
 * @version 1.0.0
 * @ClassName ClassGenerator.java
 * @Description TODO
 * @createTime 2021年12月28日 15:06:00
 */
@Slf4j
public class ClassGenerator {

    private static final Map<ClassLoader, ClassPool> classPools = new HashMap<>();

    private ClassPool classPool;
    private CtClass ctClass;
    private String              className;
    private Set<String> interfaces = null;
    private List<String> fields     = null;
    private List<String>        methods    = null;

    public ClassGenerator() {
        this.classPool = getClassPool(Thread.currentThread().getContextClassLoader());
    }

    public ClassGenerator(ClassLoader classLoader) {
        this.classPool = getClassPool(classLoader);
    }

    public Class<?> toClass() {
        try {
            // 若ctClass不为空，则将其从classPool移除，避免OutOfMemory
            if (ctClass != null) ctClass.detach();

            ctClass = classPool.makeClass(className);

            if (interfaces != null) {
                for (String api : interfaces) ctClass.addInterface(classPool.get(api));
            }

            if (fields != null) {
                for (String field : fields) ctClass.addField(CtField.make(field, ctClass));
            }

            if (methods != null) {
                for (String method : methods) ctClass.addMethod(CtNewMethod.make(method, ctClass));
            }

            return ctClass.toClass();
        } catch (Exception e) {
            log.error("generate class {} failure", ctClass != null ? ctClass.getName() : null, e);
            throw new RuntimeException(e);
        } finally {
            release();
        }
    }

    public void addInterface(Class clazz) {
        if (interfaces == null) {
            interfaces = new HashSet<>();
        }
        interfaces.add(clazz.getName());
    }

    public void addMethod(String code) {
        if (methods == null) {
            methods = new ArrayList<>();
        }
        methods.add(code);
    }

    public void addField(String code) {
        if (fields == null) {
            fields = new ArrayList<>();
        }
        fields.add(code);
    }

    public void addMethod(Method method, String body) {
        StringBuilder buffer = new StringBuilder();
        buffer.append(getModifier(method.getModifiers())).append(" ").append(ReflectUtils.getName(method.getReturnType())).append(" ").append(method.getName());
        buffer.append('(');
        Class[] parameterTypes = method.getParameterTypes();
        for (int i = 0; i < parameterTypes.length; i++) {
            if (i > 0) {
                buffer.append(',');
            }
            buffer.append(ReflectUtils.getName(parameterTypes[i]));
            buffer.append(" arg").append(i);
        }
        buffer.append(')');
        Class[] exceptionTypes = method.getExceptionTypes();
        if (exceptionTypes != null && exceptionTypes.length > 0) {
            buffer.append(" throws ");
            for (int i = 0; i < exceptionTypes.length; i++) {
                if (i > 0) buffer.append(',');
                buffer.append(ReflectUtils.getName(exceptionTypes[i]));
            }
        }
        buffer.append('{').append(body).append('}');
        addMethod(buffer.toString());
    }

    public void release() {
        if (ctClass != null) ctClass.detach();
        if (interfaces != null) interfaces.clear();
        if (fields != null) fields.clear();
        if (methods != null) methods.clear();
    }

    public void setClassName(String className) {
        this.className = className;
    }

    private String getModifier(int mod) {
        if (java.lang.reflect.Modifier.isPublic(mod)) return "public";
        if (java.lang.reflect.Modifier.isProtected(mod)) return "protected";
        if (java.lang.reflect.Modifier.isPrivate(mod)) return "private";
        return "";
    }

    private ClassPool getClassPool(ClassLoader classLoader) {
        if (classLoader == null) {
            return ClassPool.getDefault();
        }

        if (!classPools.containsKey(classLoader)) {
            synchronized (classPools) {
                if (!classPools.containsKey(classLoader)) {
                    ClassPool pool = ClassPool.getDefault();
                    pool.appendClassPath(new LoaderClassPath(classLoader));
                    classPools.put(classLoader, pool);
                }
            }
        }

        return classPools.get(classLoader);
    }

}
